-
-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
stackless fixes #37
base: main
Are you sure you want to change the base?
stackless fixes #37
Conversation
`pytest tests` passes with these changes
@patrick-kidger @nstarman I forgot about this until just now... :P Does this seem like it's worth merging? |
@@ -141,76 +141,51 @@ def _wrap_if_array(x: Union[ArrayLike, "Value"]) -> "Value": | |||
|
|||
|
|||
class _QuaxTrace(core.Trace[_QuaxTracer]): | |||
def pure(self, val: ArrayLike) -> _QuaxTracer: | |||
if _is_value(val): | |||
raise TypeError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's important to keep this behavior, can you give me a test that exercises it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure thing! The following should raise an error:
import quax
import jax
x = jax.numpy.arange(4.).reshape(2, 2)
key = jax.random.key(0)
y = quax.lora.LoraArray(x, rank=1, key=key)
def f(x):
return jax.lax.add_p.bind(x, y)
quax.quaxify(f)(y)
Basically this is just a check that all of our Quax values are properly wrapped into tracers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay! Finding time to get back to this now :D
IIUC from what I'm reading I think this largely LGTM!
def to_value(self, val): | ||
if isinstance(val, _QuaxTracer) and val._trace.tag is self.tag: | ||
return val.value | ||
else: | ||
return _DenseArrayValue(val) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For my own understanding of how new-JAX works, is it expected that we ever hit the else
branch? Old-JAX had it such that the tracers
of process_primitive(..., tracers, ...)
was guaranteed to be tracers from the current trace, which would imply always hitting the if
statement here.
values = tuple( | ||
x.array if isinstance(x, _DenseArrayValue) else x for x in values | ||
) | ||
try: | ||
rule = _rules[primitive] | ||
except KeyError: | ||
out = _default_process(primitive, values, params) | ||
with core.set_current_trace(self.parent_trace): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aha! It's super interesting to see that new-JAX now has each trace record the parent in a stack like this. I've done a couple of hobby projects reimplementing various kinds of JAXlike designs (sadly none of them released yet), and I've always ended up going for something similar.
(Albeit I've tended to go the opposite way and do full-data-dependence rather than full-dynamic-context dependence, but I'm guessing you're constrained there by the desire to do omnistaging.)
else: | ||
out = method(*values, **params) | ||
with core.set_current_trace(self.parent_trace): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I think you could hoist each core.set_current_trace
out of all these try/except blocks and have just one wrapping the whole thing.
@mattjj, if there's any way I can help on this PR, LMK :) |
From jax-ml/jax#25372, this is an attempt at fixing up quax in light of the JAX core rewrite in jax-ml/jax@c36e1f7 (aka "stackless").
Discussion: jax-ml/jax#25372